{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Handwritten Image Classification with Vision in Transformers (ViT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import Dataset, random_split, DataLoader\n",
    "\n",
    "from torchvision import datasets, transforms, models\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import Dataset, random_split, DataLoader\n",
    "from torchvision.utils import save_image\n",
    "\n",
    "from torchsummary import summary\n",
    "\n",
    "import spacy\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import os\n",
    "import time\n",
    "import math\n",
    "from PIL import Image\n",
    "import glob\n",
    "from IPython.display import display"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(0)\n",
    "np.random.seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 64\n",
    "LR = 5e-5\n",
    "NUM_EPOCHES = 25"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean, std = (0.5,), (0.5,)\n",
    "\n",
    "transform = transforms.Compose([transforms.ToTensor(),\n",
    "                                transforms.Normalize(mean, std)\n",
    "                              ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainset = datasets.MNIST('../data/MNIST/', download=True, train=True, transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "\n",
    "testset = datasets.MNIST('../data/MNIST/', download=True, train=False, transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformer_package.models import ViT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ViT(\n",
       "  (dropout_layer): Dropout(p=0.2, inplace=False)\n",
       "  (embeddings): Linear(in_features=49, out_features=512, bias=True)\n",
       "  (encoders): ModuleList(\n",
       "    (0): VisionEncoder(\n",
       "      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "      (attention): MultiHeadAttention(\n",
       "        (dropout_layer): Dropout(p=0.2, inplace=False)\n",
       "        (Q): Linear(in_features=512, out_features=512, bias=True)\n",
       "        (K): Linear(in_features=512, out_features=512, bias=True)\n",
       "        (V): Linear(in_features=512, out_features=512, bias=True)\n",
       "        (linear): Linear(in_features=512, out_features=512, bias=True)\n",
       "      )\n",
       "      (mlp): Sequential(\n",
       "        (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "        (1): GELU()\n",
       "        (2): Dropout(p=0.2, inplace=False)\n",
       "        (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "        (4): Dropout(p=0.2, inplace=False)\n",
       "      )\n",
       "    )\n",
       "    (1): VisionEncoder(\n",
       "      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "      (attention): MultiHeadAttention(\n",
       "        (dropout_layer): Dropout(p=0.2, inplace=False)\n",
       "        (Q): Linear(in_features=512, out_features=512, bias=True)\n",
       "        (K): Linear(in_features=512, out_features=512, bias=True)\n",
       "        (V): Linear(in_features=512, out_features=512, bias=True)\n",
       "        (linear): Linear(in_features=512, out_features=512, bias=True)\n",
       "      )\n",
       "      (mlp): Sequential(\n",
       "        (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "        (1): GELU()\n",
       "        (2): Dropout(p=0.2, inplace=False)\n",
       "        (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "        (4): Dropout(p=0.2, inplace=False)\n",
       "      )\n",
       "    )\n",
       "    (2): VisionEncoder(\n",
       "      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "      (attention): MultiHeadAttention(\n",
       "        (dropout_layer): Dropout(p=0.2, inplace=False)\n",
       "        (Q): Linear(in_features=512, out_features=512, bias=True)\n",
       "        (K): Linear(in_features=512, out_features=512, bias=True)\n",
       "        (V): Linear(in_features=512, out_features=512, bias=True)\n",
       "        (linear): Linear(in_features=512, out_features=512, bias=True)\n",
       "      )\n",
       "      (mlp): Sequential(\n",
       "        (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "        (1): GELU()\n",
       "        (2): Dropout(p=0.2, inplace=False)\n",
       "        (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "        (4): Dropout(p=0.2, inplace=False)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "  (classifier): Sequential(\n",
       "    (0): Linear(in_features=512, out_features=10, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "image_size = 28\n",
    "channel_size = 1\n",
    "patch_size = 7\n",
    "embed_size = 512\n",
    "num_heads = 8\n",
    "classes = 10\n",
    "num_layers = 3\n",
    "hidden_size = 256\n",
    "dropout = 0.2\n",
    "\n",
    "model = ViT(image_size, channel_size, patch_size, embed_size, num_heads, classes, num_layers, hidden_size, dropout=dropout).to(device)\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input Image Dimensions: torch.Size([64, 1, 28, 28])\n",
      "Label Dimensions: torch.Size([64])\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Output Dimensions: torch.Size([64, 10])\n"
     ]
    }
   ],
   "source": [
    "for img, label in trainloader:\n",
    "    img = img.to(device)\n",
    "    label = label.to(device)\n",
    "    \n",
    "    print(\"Input Image Dimensions: {}\".format(img.size()))\n",
    "    print(\"Label Dimensions: {}\".format(label.size()))\n",
    "    print(\"-\"*100)\n",
    "    \n",
    "    out = model(img)\n",
    "    \n",
    "    print(\"Output Dimensions: {}\".format(out.size()))\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.NLLLoss()\n",
    "optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-------------------------------------------------\n",
      "Epoch: 1 Train mean loss: 640.67601412\n",
      "       Train Accuracy%:  77.06833333333333 == 46241 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 2 Train mean loss: 228.43071219\n",
      "       Train Accuracy%:  92.44 == 55464 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 3 Train mean loss: 174.51129506\n",
      "       Train Accuracy%:  94.15166666666667 == 56491 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 4 Train mean loss: 143.14635622\n",
      "       Train Accuracy%:  95.10333333333334 == 57062 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 5 Train mean loss: 127.15741576\n",
      "       Train Accuracy%:  95.685 == 57411 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 6 Train mean loss: 113.60460257\n",
      "       Train Accuracy%:  96.085 == 57651 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 7 Train mean loss: 100.84070168\n",
      "       Train Accuracy%:  96.55666666666667 == 57934 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 8 Train mean loss: 93.53336407\n",
      "       Train Accuracy%:  96.83833333333334 == 58103 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 9 Train mean loss: 83.18391872\n",
      "       Train Accuracy%:  97.12166666666667 == 58273 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 10 Train mean loss: 77.65446205\n",
      "       Train Accuracy%:  97.23833333333333 == 58343 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 11 Train mean loss: 71.72018375\n",
      "       Train Accuracy%:  97.525 == 58515 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 12 Train mean loss: 66.29108043\n",
      "       Train Accuracy%:  97.67166666666667 == 58603 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 13 Train mean loss: 61.66031387\n",
      "       Train Accuracy%:  97.77333333333333 == 58664 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 14 Train mean loss: 59.20973012\n",
      "       Train Accuracy%:  97.945 == 58767 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 15 Train mean loss: 54.94569581\n",
      "       Train Accuracy%:  98.00833333333334 == 58805 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 16 Train mean loss: 52.12647567\n",
      "       Train Accuracy%:  98.15 == 58890 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 17 Train mean loss: 46.19083659\n",
      "       Train Accuracy%:  98.34333333333333 == 59006 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 18 Train mean loss: 45.18581017\n",
      "       Train Accuracy%:  98.37 == 59022 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 19 Train mean loss: 41.96202690\n",
      "       Train Accuracy%:  98.52333333333333 == 59114 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 20 Train mean loss: 41.38425341\n",
      "       Train Accuracy%:  98.57666666666667 == 59146 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 21 Train mean loss: 38.17168036\n",
      "       Train Accuracy%:  98.63 == 59178 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 22 Train mean loss: 35.98693700\n",
      "       Train Accuracy%:  98.72666666666667 == 59236 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 23 Train mean loss: 34.32414370\n",
      "       Train Accuracy%:  98.775 == 59265 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 24 Train mean loss: 33.43680637\n",
      "       Train Accuracy%:  98.75 == 59250 / 60000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 25 Train mean loss: 31.97760608\n",
      "       Train Accuracy%:  98.845 == 59307 / 60000\n",
      "-------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "loss_hist = {}\n",
    "loss_hist[\"train accuracy\"] = []\n",
    "loss_hist[\"train loss\"] = []\n",
    "\n",
    "for epoch in range(1, NUM_EPOCHES+1):\n",
    "    model.train()\n",
    "    \n",
    "    epoch_train_loss = 0\n",
    "        \n",
    "    y_true_train = []\n",
    "    y_pred_train = []\n",
    "        \n",
    "    for batch_idx, (img, labels) in enumerate(trainloader):\n",
    "        img = img.to(device)\n",
    "        labels = labels.to(device)\n",
    "        \n",
    "        preds = model(img)\n",
    "        \n",
    "        loss = criterion(preds, labels)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        y_pred_train.extend(preds.detach().argmax(dim=-1).tolist())\n",
    "        y_true_train.extend(labels.detach().tolist())\n",
    "            \n",
    "        epoch_train_loss += loss.item()\n",
    "    \n",
    "    loss_hist[\"train loss\"].append(epoch_train_loss)\n",
    "    \n",
    "    total_correct = len([True for x, y in zip(y_pred_train, y_true_train) if x==y])\n",
    "    total = len(y_pred_train)\n",
    "    accuracy = total_correct * 100 / total\n",
    "    \n",
    "    loss_hist[\"train accuracy\"].append(accuracy)\n",
    "    \n",
    "    print(\"-------------------------------------------------\")\n",
    "    print(\"Epoch: {} Train mean loss: {:.8f}\".format(epoch, epoch_train_loss))\n",
    "    print(\"       Train Accuracy%: \", accuracy, \"==\", total_correct, \"/\", total)\n",
    "    print(\"-------------------------------------------------\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAagUlEQVR4nO3de3Sc9X3n8fdX99FdlmRblpCNAySAjYG4DgmXJhvSTdlsgKQ0pCFlUwpsmpyEkG2bbU8Xtml2k2xSctJuewoNLWRDGnIrKWR7IN4Ekk0gERSMsdeYgu3YlmXJF0mjy9z03T/mkSXbuvnyzFjz+7zO0ZmZZ2ak7+M5/n2e5/c8z3fM3RERkXCVFbsAEREpLgWBiEjgFAQiIoFTEIiIBE5BICISuIpiF7AQbW1tvmrVqmKXISKyqDz77LMD7t4+3+sWRRCsWrWKnp6eYpchIrKomNnOhbxOU0MiIoFTEIiIBE5BICISOAWBiEjgFAQiIoFTEIiIBE5BICISuEVxHYGISKkaz+QYHMswOJbh8Ghm2v00Q2MZ3vvGLla21sVag4JARIKRzk5EA236yKB7eDTD0HiGsUyO8cwEqUyO8ej+eHbqfiobLcvkSGUnyE5MUFFWRkWZUV5mVJaXUVFuVJRZfvnk/fKyI7dj6RyDY+mjBv1UdmLWes3gkpUtCgIRCVM2N8F4doKx9ORgnGMsk2MsnTsyaE9fNp7NMR49NzSWzQ+20YA/NJbh8FiG0XRu3r9bXVFGTWU5NZXRbUX+fnVlOQ01FbQ3VFNTWU5FmZGdcHITE2RyTjY3QXbCyeac3IQzns2Rm/CjnqupLKc5UcnqtnqaEpU011bSmKg8cr8pUUlzooqmaFlDTQVlZRb7v7WCQERikUxl2Tc4xt7D4+wbHKd3cJz+5Dij0cA+OaCPRVvh0wf5VGaCdG72LeW51FSW0ViTH1ibE1V0tdTS3FlJ8/QBt7bqqMeNNZUkqsqprijDLP6B90yjIBCRo7hHW7HTtnRzE05mIn//yHNZZ2AkdWSQ3zc4Ft3mf4ZT2eN+95K6KmqryklUlpOoym9tNyUqSTTmt7ITleX52+i5RFXZ0csnn4u22CcfTz4f6kB+qhQEIgEZHs+w88Aouw6OsuPACLsOTN0eGElHUx0n/j3mZtBeX01Hc4LXtddz+TltdDTVsLypho6mBB1NNSxtrKa6ojyGtZJTpSAQKSGZ3AT9wyl6B8fZdXCEHQP5QX/ngRF2RoP9dG31VXQvqeWy1a20NVRTWZ4/0FlZbpRHt5MHOiefqyiPDoyWGa31VSxvSrC0oZrKcp2NvlgpCEQWgdyEcyCZom8oRd/QOH3D4/QNpdg/NJ5/PJRi//A4B0bS+LQNejNY0ZSge0ktv3bhMrqX1LGqtZbu1lpWttZRX60hQBQEIgUxls6x5/AYuw+NcnAkzWg6x2g6y0gqf3B0JJVlLJ1jJJ2NnouWRc8dHElz7IyNGbTVV7OssZqOphrWndXMssZqljXWsKyxmu4ldZy1JKHpGJmXgkDkNEimsuw5lB/o8wP+2FGPB5LpWd9bU1lGXVUFtdXl1FZGt1XlLKmrpa6qnERVBW31VSxtrGFZw+RAX0NbfRUVmo6R00BBIHKMiQlneDzL4bE0h0YzHBpNMxjdHh7NX/F5aDR/XvqBZIo9h8c4PJo56ndUVZTR1ZygsyXBBSsa6WqppbM5QVdLgtb6auqqy6mtqiBRWU55Ac4TF5mLgkCCkc1N0J/Mn+7YNxSd5hjNs++L5toPjuSv+pztxBkzps5Rr61iWWMNl3Q3HzXQd7YkaKurLsiFQCKng4JASspoOsvmPUNs2n2YHQdG2DcYHVwdGmcgmTpugK8sN5Y25E9zfP3yBpbUVdFSW0VzdMFRS10lTYkqWmoraamtojFRqS14KTkKAlm0UtkcW3uH2bT7MJt2D7Jp92Fe2Z88Mtg3JSrpaMrPp1/Q0Zg/kNpU
w/Jojn15Uw1Laqu05S7BUxDIopDJTbC9L5kf9PfkB/1t+4bJ5PKjfmtdFRd1NfHrazq4qKuJtV1NLG2oKXLVIouDgkCKJpXNcXAkzYFkOn87kpq6n0xzIFp2cCTNvsHxI10aG2squKirmd+9cjXruppY29XMiqYatRYQOUkKAond8HiGl/YO8eLuQTbtGWTL3kH2D6Vm7EUDHLlidUldNa11+Stfl11Yw4UrGlnX1czK1loN+iKnkYJATqvRdJaX9g6xafcgL0bTOK8NjBy52rWzOcGazkauOq+d1rposK+viu5X0VpfTWNNhQZ6kQJSEMhJG0vn2NI7xOY9g/mBf8/RB2uXN9awtquJ6y/uZG1XE2s7m2itry5u0SJyHAWBLMh4JsfWowb9QbbvTx7pVNlWX8VFXc1TB2s7m1jaqIO1IouBgkCOk8rm2LZvOJreyQ/6L/cNk40G/SV1VaztbOLq85cd2dLv0MFakUVLQSC4O1t6h3hsUy9Pbe8/6rTM5tpK1nY2cdvrV0enZeoMHZFSoyAI2LZ9wzy6aS+Pberl1YERysuMX1nVwi1XrD4yvdPVktCgL1LiFASBeWX/MI9u6uWxTb1s35+kzOCy1a387pWr+bcXLtPBXJEAKQgC8NrACI++sJfHXuzl/+0bxgw2rFrCp69bwzsvXE57gwZ/kZApCEqQu7Otb5iNW/fz2KZetvQOAfArq1q4+99fwDVrO3RGj4gcoSAoEcPjGf7vKwP8aFs/T77cT+/gOACXdjfzJ++6gGvWLqejKVHkKkXkTKQgWKQmt/p/tK2fH23bT8+OQ2QnnIbqCq48r41PnLeUX319O8u05S8i81AQLCLJVJafbB/gyZf386NtU1v9b1jewK1Xreat57Vz6coWKvX1hSJyAhQEZ7jchLNxax9fe2YXP/3XATI5p766givOaeOOq9v51fOWsrxJW/0icvIUBGeogWSKb/zilzz0zC72HB5jeWMNv3P52bztDUt5o7b6ReQ0UhCcQdyd53Yd4sGf7eT7L/aSyTmXn9PKn7zrfK4+fxkVGvxFJAYKgjPAaDrLI8/v5as/28mW3iEaqiv4wJtWctNlKzlnaX2xyxOREhdrEJjZx4FbAQPuc/cvmdnd0bL+6GV/5O7fj7OOM9Wr/Um++vROvvXsbobHs7xheQOfuX4N113cSV21MlpECiO20cbM1pAf8DcAaeCfzeyx6Ol73P0Lcf3tM92zOw/xpR+8zI+3D1BZbrxzTQe//eaVrF/Zor4+IlJwcW52ng887e6jAGb2JHB9jH/vjDeeyXHPD17mvqdepb2hmk++4zzet+Esfcm6iBRVnEGwGfiMmbUCY8A1QA9wAPiomf129PiT7n7o2Deb2W3AbQDd3d0xllkYm/cMcufDz/NyX5L3bziLP/53F1Cv6R8ROQOYT36ZbBy/3OwW4CNAEthCPhA+CwwADnwa6HD335nr96xfv957enpiqzNOmdwEf/XDf+Uv/s92ltRV8bnfuIi3vX5pscsSkQCY2bPuvn6+18W6SeruXwG+EhX034Dd7t43+byZ3Qc8GmcNxbS9b5hPfvMFNu0e5LqLV3D3uy+kubaq2GWJiBwl7rOGlrr7fjPrBt4DvNnMOty9N3rJ9eSnkEpKbsK5/yev8T8e30ZdVTl/9YFLuWZtR7HLEhGZUdyT1N+OjhFkgI+4+yEz+6qZXUx+amgHcHvMNRTUrgOj/KdvvsDPdxzk6vOX8d/fs1b9/kXkjBb31NCVMyz7YJx/s1jcnYd+vovPPLaVcjO+eMM63nNpp04HFZEznk5bOQ16B8f4g29t4sfbB7jinDY+/xsXsaJZvf9FZHFQEJyizXsG+a37niaTcz593RpuelO39gJEZFFREJyC/cPj3PpgD/XVFTx062WsaqsrdkkiIidMQXCSUtkcH/5fz3FoNM23P/wWhYCILFoKgpPg7vyXf3yJZ3ce4i9/6xIuXNFU7JJERE6aGtyfhAd/
tpNv9PySj77tHN510YpilyMickoUBCfop68M8KePbuHq85dx5zvOK3Y5IiKnTEFwAnYdGOX3HnqO1W113PO+dZSV6ewgEVn8FAQLlExlufXBHtzhb29eT0NNZbFLEhE5LXSweAEmJpw7v/E8r/QneeBDG1jZqjOERKR0aI9gAb60cTuPb+njj685nyvObSt2OSIip5WCYB7/+8VevrxxOze8sYsPXb6q2OWIiJx2CoI5bNk7xJ0Pv8Al3c382fVr1DpCREqSgmAWB5Ipbn2wh6ZEJX9z0xuprigvdkkiIrHQweIZZHIT/N7XnqM/meKbt7+ZpY36cnkRKV3aI5jBn/7TFp557SCff+9FrDurudjliIjESkFwjIee2cVXn97J7Vet5rpLOotdjohI7BQEx7jnBy/zprOX8AfvfEOxSxERKQgFwTSZ3AQDyRSXrW6lXO0jRCQQCoJpDo6kcYc2fdm8iAREQTBN/3AKgPZ6BYGIhENBMM1AMgqChqoiVyIiUjgKgmkGkmkA2rRHICIBURBMMzk1pCAQkZAoCKYZSKaorSqnrloXXItIOBQE0wwkU9obEJHgKAim6R9O0VavA8UiEhYFwTQDyRTtuoZARAKjIJhmIJnW1JCIBEdBEMnkJjg0qiAQkfAoCCJqLyEioVIQRNReQkRCpSCIqL2EiIRKQRBRewkRCZWCIKL2EiISKgVBRO0lRCRUsQaBmX3czDab2Utmdke0bImZPWFm26PbljhrWCi1lxCRUMUWBGa2BrgV2ACsA95lZucCnwI2uvu5wMbocdGpvYSIhCrOPYLzgafdfdTds8CTwPXAtcAD0WseAK6LsYYF0x6BiIQqziDYDFxlZq1mVgtcA5wFLHP3XoDodmmMNSzYQDKtPkMiEqTYjoy6+1Yz+xzwBJAEXgCyC32/md0G3AbQ3d0dS42T1F5CREIW68Fid/+Ku1/q7lcBB4HtQJ+ZdQBEt/tnee+97r7e3de3t7fHWabaS4hI0OI+a2hpdNsNvAf4OvA94OboJTcDj8RZw0KovYSIhCzuk+a/bWatQAb4iLsfMrPPAg+b2S3ALuCGmGuYl9pLiEjIYg0Cd79yhmUHgLfH+XdPlNpLiEjIdGUxai8hImFTEKD2EiISNgUBuphMRMKmIGAyCHSgWETCpCBgss+Q9ghEJEwLCgIzqzOzsuj+eWb2bjOrjLe0wlF7CREJ2UL3CJ4Casysk3zH0A8Bfx9XUYWk9hIiErqFBoG5+yj5q4P/wt2vBy6Ir6zCUXsJEQndgoPAzN4MfAB4LFpWEudaTrWX0MFiEQnTQoPgDuA/A99195fMbDXww9iqKqCp9hLaIxCRMC1oq97dnyT/xTJEB40H3P1jcRZWKGovISKhW+hZQw+ZWaOZ1QFbgG1m9vvxllYYai8hIqFb6NTQBe4+RP5rJb8PdAMfjKuoQhpIpkhUqr2EiIRroUFQGV03cB3wiLtnAI+tqgIaSKZ0fEBEgrbQIPgbYAdQBzxlZiuBobiKKiS1lxCR0C0oCNz9y+7e6e7XeN5O4G0x11YQai8hIqFb6MHiJjP7czPriX6+SH7vYNFTewkRCd1Cp4buB4aB34x+hoC/i6uoQsmqvYSIyIKvDn6du7932uP/ambPx1BPQam9hIjIwvcIxszsiskHZnY5MBZPSYWzX+0lREQWvEfwH4EHzawpenwIuDmekgpH7SVERBbeYuIFYJ2ZNUaPh8zsDmBTjLXFTu0lRERO8BvK3H0ousIY4M4Y6ikotZcQETm1r6q001ZFkai9hIjIqQXBom8xofYSIiLzHCMws2FmHvANSMRSUQGpvYSIyDxB4O4NhSqkGPqHU6xqLYkLpEVETtqpTA0tegPJtC4mE5HgBRsEk+0l2nXGkIgELtggUHsJEZG8YINA7SVERPKCDYLJ9hK6mExEQhdwEOTbS+g6AhEJXcBBoD0CEREIOAj6h9VeQkQEAg4CtZcQEckLOgjUXkJEJOYgMLNP
mNlLZrbZzL5uZjVmdreZ7TGz56Ofa+KsYTb9wykdHxARIcYgMLNO4GPAendfA5QDN0ZP3+PuF0c/34+rhrmovYSISF7cU0MVQMLMKoBaYG/Mf29B1F5CRGRKbEHg7nuALwC7gF5g0N0fj57+qJltMrP7zaxlpveb2W1m1mNmPf39/ae1NrWXEBGZEufUUAtwLXA2sAKoM7ObgL8GXgdcTD4gvjjT+939Xndf7+7r29vbT2ttai8hIjIlzqmhq4HX3L3f3TPAd4C3uHufu+fcfQK4D9gQYw0z0sVkIiJT4gyCXcBlZlZrZga8HdhqZh3TXnM9sDnGGmak9hIiIlNiu6zW3Z8xs28BzwFZ4F+Ae4G/NbOLyX8F5g7g9rhqmI32CEREpsTaX8Hd7wLuOmbxB+P8mwuh9hIiIlOCvLJ4IJmirUEHikVEIOAg0DUEIiJ5QQaB2kuIiEwJMgjUXkJEZEpwQTDZXkJ7BCIiecEFwWR7CV1DICKSF1wQqL2EiMjRggsCXUwmInK0AINA7SVERKYLMAi0RyAiMl1wQaD2EiIiRwsuCNReQkTkaEEGgdpLiIhMCS8IhnUxmYjIdMEFQX8ypfYSIiLTBBUEai8hInK8oIJA7SVERI4XVBCovYSIyPGCCgJdTCYicrzAgiDfXkJBICIyJbAgiKaGdIxAROSIoIJA7SVERI4XVBCovYSIyPHCCwIdHxAROUpYQTCcVp8hEZFjBBUEai8hInK8YIJA7SVERGYWTBCovYSIyMyCCYL+pNpLiIjMJJwgGFZ7CRGRmQQTBGovISIys4CCQO0lRERmEkwQqL2EiMjMggkCtZcQEZlZWEGg4wMiIscJJwjUXkJEZEaxBoGZfcLMXjKzzWb2dTOrMbMlZvaEmW2PblvirGGS2kuIiMwstiAws07gY8B6d18DlAM3Ap8CNrr7ucDG6HGs1F5CRGR2cU8NVQAJM6sAaoG9wLXAA9HzDwDXxVzDVHsJXVUsInKc2ILA3fcAXwB2Ab3AoLs/Dixz997oNb3A0pneb2a3mVmPmfX09/efUi39uoZARGRWcU4NtZDf+j8bWAHUmdlNC32/u9/r7uvdfX17e/sp1aL2EiIis4tzauhq4DV373f3DPAd4C1An5l1AES3+2OsAVB7CRGRucQZBLuAy8ys1swMeDuwFfgecHP0mpuBR2KsAZhqL6GzhkREjhdbvwV3f8bMvgU8B2SBfwHuBeqBh83sFvJhcUNcNUw60l6iqjzuPyUisujE2njH3e8C7jpmcYr83kHBTLaXyO+YiIjIdEFcWaz2EiIiswsjCNReQkRkVkEEgdpLiIjMruSDQO0lRETmVvJBoPYSIiJzK/kgUHsJEZG5lX4QqL2EiMicSj4I1F5CRGRuAQSB2kuIiMyl9INA7SVEROZU8kHQr/YSIiJzKvkgUHsJEZG5lX4QDOtiMhGRuZR8EPQnU7qGQERkDiUdBGovISIyv5IOArWXEBGZX0kHwWR7Ce0RiIjMrrSDYFh9hkRE5lPSQaD2EiIi8yvxIFB7CRGR+ZR2EKi9hIjIvEo6CM5ZWs+7161QewkRkTlUFLuAON24oZsbN3QXuwwRkTNaSe8RiIjI/BQEIiKBUxCIiAROQSAiEjgFgYhI4BQEIiKBUxCIiAROQSAiEjhz92LXMC8z6wd2nuTb24CB01jOYhPy+mvdwxXy+k9f95Xu3j7fGxZFEJwKM+tx9/XFrqNYQl5/rXuY6w5hr//JrLumhkREAqcgEBEJXAhBcG+xCyiykNdf6x6ukNf/hNe95I8RiIjI3ELYIxARkTkoCEREAlfSQWBm7zSzbWb2ipl9qtj1FJKZ7TCzF83seTPrKXY9cTOz+81sv5ltnrZsiZk9YWbbo9uWYtYYl1nW/W4z2xN9/s+b2TXFrDEuZnaWmf3QzLaa2Utm9vFoeSif/Wzrf0Kff8keIzCzcuBl4B3AbuAXwPvdfUtRCysQM9sBrHf3IC6qMbOrgCTw
oLuviZZ9Hjjo7p+NNgRa3P0Pi1lnHGZZ97uBpLt/oZi1xc3MOoAOd3/OzBqAZ4HrgP9AGJ/9bOv/m5zA51/KewQbgFfc/VV3TwP/AFxb5JokJu7+FHDwmMXXAg9E9x8g/x+k5Myy7kFw9153fy66PwxsBToJ57Ofbf1PSCkHQSfwy2mPd3MS/0CLmAOPm9mzZnZbsYspkmXu3gv5/zDA0iLXU2gfNbNN0dRRSU6NTGdmq4BLgGcI8LM/Zv3hBD7/Ug4Cm2FZac6Dzexyd78U+HXgI9H0gYTjr4HXARcDvcAXi1pNzMysHvg2cIe7DxW7nkKbYf1P6PMv5SDYDZw17XEXsLdItRScu++NbvcD3yU/VRaavmgOdXIudX+R6ykYd+9z95y7TwD3UcKfv5lVkh8Ev+bu34kWB/PZz7T+J/r5l3IQ/AI418zONrMq4Ebge0WuqSDMrC46cISZ1QG/Bmye+10l6XvAzdH9m4FHilhLQU0OgpHrKdHP38wM+Aqw1d3/fNpTQXz2s63/iX7+JXvWEEB0ytSXgHLgfnf/THErKgwzW01+LwCgAnio1NfdzL4OvJV8C94+4C7gH4GHgW5gF3CDu5fcQdVZ1v2t5KcFHNgB3D45Z15KzOwK4MfAi8BEtPiPyM+Th/DZz7b+7+cEPv+SDgIREZlfKU8NiYjIAigIREQCpyAQEQmcgkBEJHAKAhGRwCkIRAAzy03r1Pj86exWa2arpncGFTnTVBS7AJEzxJi7X1zsIkSKQXsEInOIvtfhc2b28+jnnGj5SjPbGDX12mhm3dHyZWb2XTN7Ifp5S/Srys3svqhn/ONmlijaSokcQ0Egkpc4ZmrofdOeG3L3DcBfkr9Snej+g+5+EfA14MvR8i8DT7r7OuBS4KVo+bnA/3T3C4HDwHtjXRuRE6Ari0UAM0u6e/0My3cA/8bdX42ae+1z91YzGyD/hSCZaHmvu7eZWT/Q5e6pab9jFfCEu58bPf5DoNLd/6wAqyYyL+0RiMzPZ7k/22tmkpp2P4eOz8kZREEgMr/3Tbv9WXT/p+Q72gJ8APhJdH8j8GHIf12qmTUWqkiRk6WtEpG8hJk9P+3xP7v75Cmk1Wb2DPkNp/dHyz4G3G9mvw/0Ax+Kln8cuNfMbiG/5f9h8l8MInLG0jECkTlExwjWu/tAsWsRiYumhkREAqc9AhGRwGmPQEQkcAoCEZHAKQhERAKnIBARCZyCQEQkcP8fxJElZkdQ9R4AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(loss_hist[\"train accuracy\"])\n",
    "plt.xlabel(\"Epoch\")\n",
    "plt.ylabel(\"Loss\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYkAAAEGCAYAAACQO2mwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAg20lEQVR4nO3deXRcZ5nn8e9Tq6TSYskq2bK8xYmT4IQsoAlLaJoOkJg0TdLNBMKwuJmck57u0E1ObyQzf/Q2mcM000xPmoHpAOk2w5L20IS42YOBhBCIcRKHxHZCjPdYtiXZsmxtpSo980ddyWVJJcvLVUm6v885de69b90qPZc64ef3vve919wdERGRycQqXYCIiMxeCgkRESlLISEiImUpJEREpCyFhIiIlJWodAHno7m52VeuXFnpMkRE5pSnn366y92z09l3TofEypUr2bJlS6XLEBGZU8xs73T31ekmEREpSyEhIiJlKSRERKQshYSIiJSlkBARkbIUEiIiUpZCQkREyopkSLzSM8Anv/cSe7r6Kl2KiMisFsmQ6OnPcf8PdvLiod5KlyIiMqtFMiSydWkAOk8MVbgSEZHZLZIhsTCTJmYKCRGRM4lkSMRjxsLaNJ0nFRIiIlOJZEgAZGvTHOlVSIiITCW6IVGnnoSIyJlEOyQ0JiEiMqXIhkRLXZquk0OMjHilSxERmbUiGxLZujTDBadnYLjSpYiIzFqRDgnQZbAiIlMJNSTMbIGZfdXMXjSzHWb2BjNrMrNHzezlYNlYsv+9ZrbTzF4ys5vCrC1bq5AQETmTsHsS/wv4jrtfDlwN7ADuATa5+2pgU7CNma0BbgeuANYCnzazeFiFtdRXAdB5cjCsPyEiMueFFhJmVg+8Gfg8gLvn3L0HuAVYH+y2Hrg1WL8FeMjdh9x9N7ATuC6s+kZPN2muhIhIeWH2JFYBncA/mdmzZvY5M8sAi9y9AyBYtgT7twH7Sz5/IGg7jZndaWZbzGxLZ2fnOReXScWpTsZ1uklEZAphhkQCeA3wGXe/FugjOLVUhk3SNuH6VHd/wN3b3b09m82ec3Fmpgl1IiJnEGZIHAAOuPtTwfZXKYbGYTNrBQiWR0r2X1by+aXAwRDro0UT6kREphRaSLj7IWC/mV0WNL0V2A5sBNYFbeuAR4L1jcDtZpY2s4uA1cDmsOqD4rjEEYWEiEhZiZC//w+BL5lZCtgFfJhiMG0wszuAfcBtAO6+zcw2UAySPHCXuxfCLC5bl+bJX3WH+SdEROa0UEPC3bcC7ZO89dYy+98H3BdmTaWytWmODwwzlC+QToR2ta2IyJwV2RnXAC31xctgu07mKlyJiMjsFOmQODVXQhPqREQmE+2QqA1mXWvwWkRkUtEOidGb/GmuhIjIpCIdEgtrU5ipJyEiUk6kQyIZj9FUk9JcCRGRMiIdEqDHmIqITEUhoZAQESlLIaGQEBEpSyERhIT7hBvOiohEnkKiNk2uMELvQL7SpYiIzDoKibG5Epp1LSIyXuRDoqWuOOtajzEVEZko8iGhWdciIuUpJEZDQlc4iYhMEPmQqK9KkErEFBIiIpOIfEiYGS16jKmIyKQiHxKgCXUiIuUoJCjOlVBIiIhMpJAg6Eno6iYRkQkUEhTnShzty5HLj1S6FBGRWUUhwanLYLv71JsQESmlkEBzJUREylFIoJAQESkn1JAwsz1m9ryZbTWzLUFbk5k9amYvB8vGkv3vNbOdZvaSmd0UZm2lWoKQ0FwJEZHTzURP4jfc/Rp3bw+27wE2uftqYFOwjZmtAW4HrgDWAp82s/gM1MfC2hSgnoSIyHiVON10C7A+WF8P3FrS/pC7D7n7bmAncN1MFJROxFlQk1RIiIiME3ZIOPA9M3vazO4M2ha5ewdAsGwJ2tuA/SWfPRC0ncbM7jSzLWa2pbOz84IVqgl1IiITJUL+/uvd/aCZtQCPmtmLU+xrk7RNeKaouz8APADQ3t5+
wZ452lKf5sgJPXhIRKRUqD0Jdz8YLI8AD1M8fXTYzFoBguWRYPcDwLKSjy8FDoZZX6lsrWZdi4iMF1pImFnGzOpG14EbgReAjcC6YLd1wCPB+kbgdjNLm9lFwGpgc1j1jTd6kz/3C9Y5ERGZ88I83bQIeNjMRv/Ol939O2b2c2CDmd0B7ANuA3D3bWa2AdgO5IG73L0QYn2nydalGRwe4eRQnrqq5Ez9WRGRWS20kHD3XcDVk7R3A28t85n7gPvCqmkqY8+6PjGkkBARCWjGdUCzrkVEJlJIBBQSIiITKSQC2VqFhIjIeAqJwIKaJMm46f5NIiIlFBIBM9OsaxGRcRQSJfQYUxGR0ykkSoxOqBMRkSKFRIliSOj+TSIioxQSJbJ1VXT35cgXRipdiojIrKCQKJGtS+MOR/tylS5FRGRWUEiUGJ0roctgRUSKFBIlNOtaROR0CokSLQoJEZHTKCRKjPUkNFdCRARQSJymKhmnriqhnoSISEAhMU62Ts+6FhEZpZAYp0WzrkVExigkxsnWVSkkREQCColxdCdYEZFTFBLjZOvS9OUK9A3lK12KiEjFKSTG0VwJEZFTFBLjaK6EiMgpColxdGsOEZFTQg8JM4ub2bNm9o1gu8nMHjWzl4NlY8m+95rZTjN7ycxuCru2yYyGxJFezZUQEZmJnsRHgR0l2/cAm9x9NbAp2MbM1gC3A1cAa4FPm1l8Buo7TVNNinjMdLpJRISQQ8LMlgK/CXyupPkWYH2wvh64taT9IXcfcvfdwE7gujDrm0wsZjTXpnS6SUSE8HsSfw/8OVD6qLdF7t4BECxbgvY2YH/JfgeCttOY2Z1mtsXMtnR2doZStJ51LSJSFFpImNk7gSPu/vR0PzJJm09ocH/A3dvdvT2bzZ5XjeVka9N68JCICJAI8buvB95lZjcDVUC9mX0ROGxmre7eYWatwJFg/wPAspLPLwUOhlhfWS11VWw72FuJPy0iMquE1pNw93vdfam7r6Q4IP0Dd/8AsBFYF+y2DngkWN8I3G5maTO7CFgNbA6rvqlk69J09+UojEzoyIiIREqYPYlyPg5sMLM7gH3AbQDuvs3MNgDbgTxwl7sXKlAf2bo0hRHnWH+O5uC51yIiUTQjIeHuPwJ+FKx3A28ts999wH0zUdNUTs2VGFJIiEikacb1JFp0aw4REUAhMSndmkNEpEghMYnRU0wKCRGJOoXEJDLpBJlUXM+6FpHIU0iU0VKvx5iKiCgkytBjTEVEFBJlZevSurpJRCJvWiFhZhkziwXrl5rZu8wsGW5plZWtS9PZq5AQkWibbk/icaDKzNooPgPiw8A/h1XUbJCtS3NiKM9AriKTvkVEZoXphoS5ez/wO8A/uPtvA2vCK6vyRudKdOmUk4hE2LRDwszeALwf+GbQVon7Ps2YsVtzaPBaRCJsuiFxN3Av8HBwI75VwA9Dq2oWyI5NqNNcCRGJrmn1Btz9MeAxgGAAu8vd/yjMwiqtpV6zrkVEpnt105fNrN7MMhRv5f2Smf1ZuKVV1sJMmpgpJEQk2qZ7ummNu/cCtwLfApYDHwyrqNkgHjOaMporISLRNt2QSAbzIm4FHnH3YSZ5/vR8k61Lc0RzJUQkwqYbEv8I7AEywONmtgKY9w+BbtGsaxGJuGmFhLvf7+5t7n6zF+0FfiPk2iouW6f7N4lItE134LrBzD5pZluC199R7FXMa9m6NF0nhxgZmfdn1kREJjXd000PAieA9wSvXuCfwipqtsjWphkuOD0Dw5UuRUSkIqY7a/pid393yfZfmdnWEOqZVUrnSjRlUhWuRkRk5k23JzFgZm8a3TCz64GBcEqaPbJ6jKmIRNx0exL/CfiCmTUE28eAdeGUNHuM3r+p86RuzSEi0TTd23I8B1xtZvXBdq+Z3Q38IsTaKm7sJn+aKyEiEXVWT6Zz995g5jXAH0+1r5lVmdlmM3vOzLaZ2V8F7U1m9qiZvRwsG0s+c6+Z7TSzl8zs
prM+mgusNp2gKhnT6SYRiazzeXypneH9IeAGd78auAZYa2avB+4BNrn7aooPMLoHwMzWALcDVwBrgU+bWfw86jtvZkZLXZUm1IlIZJ1PSEw5eSCYdHcy2EwGLwduAdYH7esp3uqDoP0hdx9y993ATuC686jvgtCEOhGJsilDwsxOmFnvJK8TwJIzfbmZxYNLZY8Aj7r7U8Aid+8ACJYtwe5twP6Sjx8I2sZ/552jk/o6Ozunc4znJVub1oOHRCSypgwJd69z9/pJXnXufsZBb3cvuPs1wFLgOjO7cordJzt9NaG34u4PuHu7u7dns9kzlXDe1JMQkSg7n9NN0+buPcCPKI41HDazVoBgeSTY7QCwrORjS4GDM1HfVFrq0hwfGGYoX6h0KSIiMy60kDCzrJktCNargbcBLwIbOTXHYh3wSLC+EbjdzNJmdhGwGtgcVn3TNXoZbNfJXIUrERGZedOdTHcuWoH1wRVKMWCDu3/DzH4KbDCzO4B9wG0AwbOzN1B88l0euMvdK/7P91NzJQZpW1Bd4WpERGZWaCHh7r8Arp2kvRt4a5nP3AfcF1ZN52Js1rXGJUQkgmZkTGIua6mrAtBcCRGJJIXEGSysLd79VT0JEYkihcQZJOMxmjIpzZUQkUhSSExDtlZzJUQkmhQS09BSr5AQkWhSSEyDehIiElUKiWkYvTWH+5T3NBQRmXcUEtOQrUuTK4zQO5CvdCkiIjNKITENeoypiESVQmIa9BhTEYkqhcQ0tIz1JBQSIhItColpyNYGt+bQFU4iEjEKiWmor06QSsQUEiISOQqJaTAzPcZURCJJITFNeoypiESRQmKaFBIiEkUKiWlqqUtz+MSgZl2LSKQoJKbp2uWN9PQPs/7JPZUuRURkxigkpundr2njhstb+G/fepHtB3srXY6IyIxQSEyTmfGJf38VC2qS/OFXnqE/p/s4icj8p5A4Cwtr0/zP917Drq4+/uYb2ytdjohI6BQSZ+n6S5r5vTdfzFc27+dbz3dUuhwRkVApJM7Bn9x4KVcvbeCef/0Fr/QMVLocEZHQhBYSZrbMzH5oZjvMbJuZfTRobzKzR83s5WDZWPKZe81sp5m9ZGY3hVXb+UrGY9z/vmsZcbj7oWfJF0YqXZKISCjC7EnkgT9x91cBrwfuMrM1wD3AJndfDWwKtgneux24AlgLfNrM4iHWd15WLMzwN7dewc/3HONTP9xZ6XJEREIRWki4e4e7PxOsnwB2AG3ALcD6YLf1wK3B+i3AQ+4+5O67gZ3AdWHVdyH89rVL+Z1r27h/08ts3n200uWIiFxwMzImYWYrgWuBp4BF7t4BxSABWoLd2oD9JR87ELSN/647zWyLmW3p7OwMte7p+Otbr2RZUw13P/Qsx/uHK12OiMgFFXpImFkt8K/A3e4+1Sw0m6Rtwj0w3P0Bd2939/ZsNnuhyjxntekE999+LUdODHHP136h23aIyLwSakiYWZJiQHzJ3b8WNB82s9bg/VbgSNB+AFhW8vGlwMEw67tQrl62gD+96TK+/cIhHvr5/jN/QERkjgjz6iYDPg/scPdPlry1EVgXrK8DHilpv93M0mZ2EbAa2BxWfRfanb+2ijdd0sxf/ds2dh45UelyREQuiDB7EtcDHwRuMLOtwetm4OPA283sZeDtwTbuvg3YAGwHvgPc5e6FEOu7oGIx45PvuZqaVIKPfPlZBofnTOkiImXZXD6H3t7e7lu2bKl0Gaf5wYuH+Y//vIXffeNK/vJdV1S6HBGRCczsaXdvn86+mnF9gd1w+SI+fP1K/vnJPWzacbjS5YiInBeFRAjuecflvKq1nj/76i840jtY6XJERM6ZQiIE6UScf3jfNfTn8nzowc3sP9pf6ZJERM6JQiIkl7TU8dkPtXOwZ4Df+tQTPPFyV6VLEhE5awqJEP3a6iwbP/ImWurSfOjBp/js47s02U5E5hSFRMhWNmd4+A+u58Y1i7nvWzu4+1+2MpDT5bEi
MjcoJGZAJp3gMx94DX9646VsfO4g7/7Mkxw4pnEKEZn9FBIzxMz4yA2r+fy6dvYf6+ddn/oJT/5K4xQiMrspJGbYDZcv4pG7rqcpk+KDn9/Mg0/s1jiFiMxaCokKWJWt5eE/eCM3XN7CX39jO3/y/57TbTxEZFZSSFRIXVWSf/zAa7n7bav52jOvcNv/+SkH9bxsEZllFBIVFIsZd7/tUj77oXZ2d/XxW//wBE/t6q50WSIiYxQSs8Db1yzi63ddT0N1kvd/7in+eMNWntl3TGMVIlJxugvsLNI7OMwnvvMSX3vmAH25Amta6/nA61dwyzVLyKQTlS5PROaJs7kLrEJiFjo5lOfrz77CF3+2lxcPnaA2neB3XtPG+1+3gssW11W6PBGZ4xQS84S788y+Y3zxZ/v45vMd5PIj/LuVjXzg9StYe+Vi0ol4pUsUkTlIITEPHe3L8dWn9/Olp/axt7ufhZkUt7Uv4/2vW86ypppKlycic4hCYh4bGXGe2NnFF3+2l+/vOIxTvJHgrdcs4cYrFlOrsQsROQOFRER0HB/goc37+erTB3ilZ4B0IsZbX9XCu65ewlsua6EqqdNRIjKRQiJiRscuNm49yDef76DrZI66dIIbr1jMb13dyvWXNJOM62pnESlSSERYvjDCT3d182/PHeTbLxzixGCepkyKm1+9mHdd3Ub7ikZiMat0mSJSQQoJAWAoX+CxlzrZ+NxBvr/jMIPDIyxpqOKdVy/hnVe18uq2BswUGCJRo5CQCfqG8nx/x2E2bj3I4y93Mlxw2hZUc/OrF7P2ylauXbZAPQyRiFBIyJR6+nM8uv0w33nhED9+uYtcYYTF9VWsvXIx77hyMe0rm4grMETmrVkREmb2IPBO4Ii7Xxm0NQH/AqwE9gDvcfdjwXv3AncABeCP3P27Z/obConz1zs4zA92HOFbz3fw2C87GcqP0Fyb5qYrFnHzq1t53UVNJDToLTKvzJaQeDNwEvhCSUj8LXDU3T9uZvcAje7+MTNbA3wFuA5YAnwfuNTdp3zIgkLiwuobyvPDl47w7ecP8YMXjzAwXKCxJsmNaxaz9tWLee2KRuqrkpUuU0TO09mERGgzr9z9cTNbOa75FuAtwfp64EfAx4L2h9x9CNhtZjspBsZPw6pPJsqkE7zzqiW886olDOQKPPbLTr79QgfffL6Df9myH4DlTTWsaa1nzZL6sWVrQ5UGwEXmqZmenrvI3TsA3L3DzFqC9jbgZyX7HQjaJjCzO4E7AZYvXx5iqdFWnYqz9srFrL1yMYPDBZ7afZQXXjnO9oO9bO/o5bvbDzHaCV1QkywGxmh4LKnn4myt5maIzAOz5R4Ok/0zdNLzYO7+APAAFE83hVmUFFUl4/z6pVl+/dLsWNvJoTwvHeodC43tB3v5vz/by1B+BIBUPMblrXW84eKFvOmSZtpXNFGd0gxwkblmpkPisJm1Br2IVuBI0H4AWFay31Lg4AzXJmehNp3gtSuaeO2KprG2fGGEPd19bAuC49l9PTz4xG7+8bFdpOIxXruikTetbub6S5p5dVuDrqASmQNCvQQ2GJP4RsnA9SeA7pKB6yZ3/3MzuwL4MqcGrjcBqzVwPff15/Js3n2Un+zs4omd3ezo6AWgrirBG1YtHAuNVc0ZjWuIzJBZMXBtZl+hOEjdbGYHgL8APg5sMLM7gH3AbQDuvs3MNgDbgTxw15kCQuaGmlSCt1zWwlsuKw4/dZ0c4qe/6uYnO7v48ctdfG/7YQBaG6p448XNXLt8ARc1Z1jZnKG1vkoT/EQqTJPppGLcnX1H+/nJzmJo/ORXXfT0D4+9n0rEWN5Uw8qFNaxcmGFFc2ZsfcmCap2uEjlHs6InIXImZsaKhRlWLMzwH163nJERp6N3kL1dfezp7mdvdx97uvvY09XPEzu7GBweGftsMm4sayoGxkXNGVZlM6xqruXibIZsXVqnrkQuEIWEzBqxmNG2oJq2BdW88ZLT3xsZcY6c
GGJPdx97u/vY3dUfLPt48lenB0htOhGERoZV2drTQkRXWImcHYWEzAmxmLG4oYrFDVW8ftXC094b7YHs6jzJrs6+4rKrj5/vOcbXt55+kdyShipWZWtZsbAmeGWKy6aMAkRkEgoJmfNKeyC/tjp72nsDuQK7u/rY1XUqQHZ39fHN5ztOG/8AaKlLs3JhhuULi+MgpQHSUKPbkUg0KSRkXqtOxcdmgY93vH+YvUeL4x/7ukeX/Tz+y06+emLotH3rqxK0NdbQtqCKtgXVLAlebY3FcMrWpnUllsxLCgmJrIaaJFfVLOCqpQsmvNefy7PvaD97gwH0A8cGeOXYAAeODfDU7qOcGMyftn8ybrQ2VI8FSNuCKpoyKeqrk9RXJYvL6sTYeiYV1+C6zAkKCZFJ1KQSXL64nssXT+yBQPEW6x09g7zS088rPYMc7CmGyMGeAZ78VReHewcZmeLq8phREiDF8GioTrKovorWYOyltaGa1oYqWurTpBMaL5HKUEiInIP6qiT1i5Nctrhu0vfzhRFODObpHRymd2B0OczxgeEJbb2DeXoHhvnl4RM8/stO+nIT55E216aKA/f11SUhUkW2Lk1jTYoFNUkW1KTUQ5ELTiEhEoJEPEZjJkVjJnXWnz0xOMyh44N0HB88tewd5NDxAQ4c62fL3qMTBt1HJeNGQ3WKxprkWHAsqE7SmEnRUJ2ksSZFW2M1q5oztC2o1jiKnJFCQmSWqatKUleVZPWiyXspULxq61DvIJ0nhujpz9HTP0zPQI5j/cP09A9zfCDHsb5hDhwb4IVXjtPTP8zA8Ok9lHQidtockrH1bC0N1bqaS4oUEiJzUHUqzkXNxdnm0zU4XOBYf4593f3s6uobm1eyo+ME3912mELJIEpzbYpVzbWsyhYvCa5JxkkmYqTiMVKly0SM5Pi2eIyaVJzGTErPFJkHFBIiEVGVjAeD4dW8btyExFx+hH1H+8fmkezqLM4teXT7Ybr7cuf8N+urEiysTbMwk6Ipk2JhbXHZlDnVNtpek0wQjxtxM+IxIxEznQ6bBRQSIkIqEeOSllouaamd8F5/Ls/g8AjDhRFy+RFyo8v8xLbhgpMrFDg5VODoyRxH+4bo7stxtC/HvqP9PLu/h2N9OfJTXfpVwoxiWFgxNOIxIxGPETMjFTcaalKnh00mRVNtioWZ9FggLcykqK9KKnDOkUJCRKZUk0pQc/bj72WNjDi9g8Nj4dF9srgcyhcojDj5EacQvIrrI8VlwSn4qfZcfoSe/hzdfTn2H+vn6MkcJ4byk/7NeMxorCkO6NdVJaitSlKXThTX0wlqqxLFsaCx9WJ7XVWSTDpOTTJBVap4Ki1qV48pJERkRsViVrzqqibFxdkz7382hvIFjvUN0903NBZAxTAqbh/rG6Yvl+f4wDCvHOvn5FCeE4N5+ie57Hgy8ZhRk4xTnYpTk4pTnUpQM7pe0p6KF8dqEvEYqbiRjMdIJmIkYjY2jlN8Fd9LxWNUJeNUp4rLqmTx+0aX6USsYj0hhYSIzBvpRJzFDXEWN1Sd1ecKIx4ExjAnh/KcHMyPzXMZyBXozxUYGC7QnysGSmnbQK5A31CezhNDY9vDhdFTb8VTchfisT3pRIzqVJyqRDGM3vaqFv7Lb645/y8+A4WEiERePGY0VCdDufTXS0+RFUYYDsZuhgvFsZx8wRnKFxgcHmFguMBg8BrIBcvhkVNtw6faFjdUX/BaJ6OQEBEJkZmRiBuJePEKs7lGFzGLiEhZCgkRESlLISEiImUpJEREpCyFhIiIlKWQEBGRshQSIiJSlkJCRETKMr8Q88UrxMw6gb3n8RXNQNcFKmeu0bFHV5SPP8rHDqeOf4W7T+vOWXM6JM6XmW1x9/ZK11EJOvZoHjtE+/ijfOxwbsev000iIlKWQkJERMqKekg8UOkCKkjHHl1RPv4oHzucw/FHekxCRESmFvWehIiITEEhISIiZUUyJMxsrZm9ZGY7zeyeStcz
08xsj5k9b2ZbzWxLpesJk5k9aGZHzOyFkrYmM3vUzF4Olo2VrDFMZY7/L83sleD332pmN1eyxrCY2TIz+6GZ7TCzbWb20aB93v/+Uxz7Wf/2kRuTMLM48Evg7cAB4OfA+9x9e0ULm0Fmtgdod/d5P6nIzN4MnAS+4O5XBm1/Cxx1948H/0hodPePVbLOsJQ5/r8ETrr7/6hkbWEzs1ag1d2fMbM64GngVuB3mee//xTH/h7O8rePYk/iOmCnu+9y9xzwEHBLhWuSkLj748DRcc23AOuD9fUU/+OZl8ocfyS4e4e7PxOsnwB2AG1E4Pef4tjPWhRDog3YX7J9gHP8H28Oc+B7Zva0md1Z6WIqYJG7d0DxPyagpcL1VMJHzOwXwemoeXe6ZTwzWwlcCzxFxH7/cccOZ/nbRzEkbJK2aJ1zg+vd/TXAO4C7glMSEh2fAS4GrgE6gL+raDUhM7Na4F+Bu929t9L1zKRJjv2sf/sohsQBYFnJ9lLgYIVqqQh3PxgsjwAPUzwFFyWHg3O2o+duj1S4nhnl7ofdveDuI8Bnmce/v5klKf6f5Jfc/WtBcyR+/8mO/Vx++yiGxM+B1WZ2kZmlgNuBjRWuacaYWSYYyMLMMsCNwAtTf2re2QisC9bXAY9UsJYZN/p/kIHfZp7+/mZmwOeBHe7+yZK35v3vX+7Yz+W3j9zVTQDBZV9/D8SBB939vspWNHPMbBXF3gNAAvjyfD5+M/sK8BaKt0g+DPwF8HVgA7Ac2Afc5u7zcnC3zPG/heLpBgf2AL83eo5+PjGzNwE/Bp4HRoLm/0zx3Py8/v2nOPb3cZa/fSRDQkREpieKp5tERGSaFBIiIlKWQkJERMpSSIiISFkKCRERKUshIXIGZlYouWvm1gt552AzW1l6h1aR2SZR6QJE5oABd7+m0kWIVIJ6EiLnKHgux383s83B65KgfYWZbQpuorbJzJYH7YvM7GEzey54vTH4qriZfTa47//3zKy6YgclMo5CQuTMqsedbnpvyXu97n4d8CmKs/gJ1r/g7lcBXwLuD9rvBx5z96uB1wDbgvbVwP929yuAHuDdoR6NyFnQjGuRMzCzk+5eO0n7HuAGd98V3EztkLsvNLMuig98GQ7aO9y92cw6gaXuPlTyHSuBR919dbD9MSDp7v91Bg5N5IzUkxA5P15mvdw+kxkqWS+gsUKZRRQSIufnvSXLnwbrT1K8uzDA+4EngvVNwO9D8TG6ZlY/U0WKnCv9i0XkzKrNbGvJ9nfcffQy2LSZPUXxH1zvC9r+CHjQzP4M6AQ+HLR/FHjAzO6g2GP4fYoPfhGZtTQmIXKOgjGJdnfvqnQtImHR6SYRESlLPQkRESlLPQkRESlLISEiImUpJEREpCyFhIiIlKWQEBGRsv4/2bdd1JaCr4sAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Draw the recorded training-loss history on the current axes.\n",
    "# The x-axis is the position of each value in the history list\n",
    "# (labelled \"Epoch\"; presumably one entry per epoch - confirm against the training loop).\n",
    "ax = plt.gca()\n",
    "ax.plot(loss_hist[\"train loss\"])\n",
    "ax.set(xlabel=\"Epoch\", ylabel=\"Loss\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Accuracy%:  98.51 == 9851 / 10000\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the model on `testloader` and report top-1 accuracy.\n",
    "#\n",
    "# Results left in the notebook namespace:\n",
    "#   y_true_test   - list of ground-truth class indices\n",
    "#   y_pred_test   - list of predicted class indices (argmax over the last dim)\n",
    "#   total_correct - number of positions where prediction == ground truth\n",
    "#   total         - number of evaluated samples\n",
    "#   accuracy      - total_correct / total, as a percentage\n",
    "\n",
    "# Inference behaviour for layers that act differently in training (if the model has any).\n",
    "model.eval()\n",
    "\n",
    "y_true_test = []\n",
    "y_pred_test = []\n",
    "\n",
    "# Gradients are not needed for evaluation.\n",
    "with torch.no_grad():\n",
    "    for img, labels in testloader:\n",
    "        # Only the images are moved to the device. The labels are just turned\n",
    "        # into a Python list below, so they are left where the loader put them.\n",
    "        # Do not reference `label` (singular) here: it is not bound in this cell\n",
    "        # and would silently pick up a stale tensor from an earlier cell.\n",
    "        img = img.to(device)\n",
    "\n",
    "        preds = model(img)\n",
    "\n",
    "        y_pred_test.extend(preds.argmax(dim=-1).tolist())\n",
    "        y_true_test.extend(labels.tolist())\n",
    "\n",
    "# Count matching (prediction, target) pairs without building a throwaway list.\n",
    "total_correct = sum(pred == true for pred, true in zip(y_pred_test, y_true_test))\n",
    "total = len(y_pred_test)\n",
    "accuracy = total_correct * 100 / total\n",
    "\n",
    "print(\"Test Accuracy%: \", accuracy, \"==\", total_correct, \"/\", total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}